package socks5

import (
	"errors"
	"fmt"
	"io"
	"log"
	"net"
)

// Sentinel errors for SOCKS5 protocol failures. Callers should compare them
// with errors.Is rather than by matching message text.
var (
	// ErrorVersionNotSupported reports an unsupported protocol version.
	ErrorVersionNotSupported = errors.New("protocol version not supported")

	// ErrorCommandNotSupported reports an unsupported request command.
	ErrorCommandNotSupported = errors.New("request command not supported")

	// ErrorReservedFieldInvalid reports an invalid value in the reserved field.
	ErrorReservedFieldInvalid = errors.New("invalid reserved field")

	// ErrorAddressTypeNotSupported reports an unsupported address type.
	ErrorAddressTypeNotSupported = errors.New("address type not supported")
)

// SOCKS5_VERSION is the protocol version number for SOCKS5.
// (ALL_CAPS naming is non-idiomatic Go but is kept because the name is exported.)
const SOCKS5_VERSION = 0x05

// REVERSED_FIELD is the value expected in the reserved field; the identifier
// presumably means "reserved" and keeps its original spelling for compatibility.
const REVERSED_FIELD = 0x00

// Server is implemented by anything that serves clients when Run is called.
// (*Socks5Server).Run, for example, blocks while serving and returns only an error.
type Server interface {
	// Run starts serving and reports why serving stopped.
	Run() error
}

// Socks5Server is a Server that accepts SOCKS5 clients on a TCP address.
type Socks5Server struct {
	// IP is the host part of the listen address; Run formats it as "IP:Port".
	IP string
	// Port is the TCP port to listen on.
	Port int
}

// Run listens for TCP connections on s.IP:s.Port and serves every accepted
// client in its own goroutine via handleConnection, closing the client
// connection when that goroutine finishes.
//
// Run blocks. It returns the error from net.Listen if the listener cannot be
// created, or the Accept error once the listener has been closed
// (net.ErrClosed). Other Accept errors are logged and accepting continues.
// The listener is always closed before Run returns.
func (s *Socks5Server) Run() error {
	address := fmt.Sprintf("%s:%d", s.IP, s.Port)
	listener, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			// A closed listener fails every later Accept; retrying would spin
			// forever flooding the log, so stop serving instead.
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			// net.TCPListener.Accept returns a nil conn together with its
			// error, so there is no remote address to report here.
			log.Printf("error accepting connection: %v", err)
			continue
		}

		log.Printf("new accept connection %v", conn.RemoteAddr())

		go func() {
			defer conn.Close()
			if err := handleConnection(conn); err != nil {
				log.Printf("handle connection failer: %v: %v", conn.RemoteAddr(), err)
			}
			log.Printf("--- connection done: %v", conn.RemoteAddr())
		}()
	}
}

// handleConnection serves one client connection end to end. It first
// negotiates the authentication method (auth), then processes the client's
// request (request), and finally relays data between the client and the
// requested target (forward). The first error from any phase is returned;
// closing conn is left to the caller.
func handleConnection(conn net.Conn) error {
	err := auth(conn)
	if err != nil {
		return err
	}

	target, err := request(conn)
	if err != nil {
		return err
	}

	return forward(conn, target)
}

// forward relays data in both directions between the client connection conn
// and targetConn until reading from targetConn reaches EOF or fails.
//
// Client-to-target copying runs in a background goroutine:
//   - When the client finishes sending (clean EOF), the target is half-closed
//     if it supports CloseWrite. The target then sees EOF and can finish its
//     response while data still flows back to the client.
//   - If that copy fails, targetConn is closed outright so the
//     target-to-client copy below unblocks instead of waiting on the target.
//
// forward never closes conn; the background goroutine exits once conn
// reaches EOF or errors (e.g. when the caller closes it). A non-nil
// targetConn is always closed before forward returns.
//
// It returns the error from the target-to-client copy (nil on clean EOF), or
// an error if targetConn is nil.
func forward(conn io.ReadWriter, targetConn io.ReadWriteCloser) error {
	if targetConn == nil {
		return errors.New("targetConn is nil")
	}
	defer targetConn.Close()

	go func() {
		if _, err := io.Copy(targetConn, conn); err != nil {
			// A second Close (here and in the defer above) only returns an
			// error, which is deliberately ignored.
			_ = targetConn.Close()
			return
		}
		if hc, ok := targetConn.(interface{ CloseWrite() error }); ok {
			// Best effort: failure just means the target won't see EOF early.
			_ = hc.CloseWrite()
		}
	}()

	_, err := io.Copy(conn, targetConn)
	return err
}

// request handles the SOCKS5 request phase on conn.
//
// It reads the client's request via NewClientRequestMessage. Only the CONNECT
// command is accepted, and IPv6 destinations are rejected. For either
// rejection the matching failure reply is written to the client and a
// non-nil error is returned. For an accepted request it dials the destination
// over TCP and replies with the local address of the new outbound connection.
//
// On success the returned connection is open and owned by the caller, who
// must close it. On any failure the returned connection is nil; a connection
// that was dialled but whose success reply could not be written is closed
// here rather than leaked.
func request(conn io.ReadWriter) (io.ReadWriteCloser, error) {
	// conn is only typed as io.ReadWriter, so look up the peer address
	// without assuming it is a net.Conn (a nil Addr is logged as <nil>).
	var client net.Addr
	if nc, ok := conn.(net.Conn); ok {
		client = nc.RemoteAddr()
	}
	log.Printf("new request: %v", client)

	message, err := NewClientRequestMessage(conn)
	if err != nil {
		return nil, err
	}

	// Only CONNECT is supported.
	if message.Cmd != CmdConnect {
		if err := WriteRequestFailureMessage(conn, ReplyCommandNotSupported); err != nil {
			return nil, err
		}
		// Return the rejection itself: the write error is nil on success and
		// would hand the caller a nil connection with no error.
		return nil, fmt.Errorf("request command %v not supported", message.Cmd)
	}

	// IPv6 destinations are rejected. NOTE(review): the "%s:%d" formatting
	// below would not bracket an IPv6 literal either.
	if message.AddrType == TypeIPv6 {
		if err := WriteRequestFailureMessage(conn, ReplyAddressTypeNotSupported); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("address type %v not supported", message.AddrType)
	}

	// Connect to the requested TCP service.
	address := fmt.Sprintf("%s:%d", message.Address, message.Port)
	targetConn, err := net.Dial("tcp", address)
	if err != nil {
		if werr := WriteRequestFailureMessage(conn, ReplyConnectionRefused); werr != nil {
			return nil, werr
		}
		return nil, err
	}
	log.Printf(">>> request: %v -> %v", client, targetConn.RemoteAddr())

	// A "tcp" dial yields a TCP connection, so its local address is a *net.TCPAddr.
	addr := targetConn.LocalAddr().(*net.TCPAddr)
	if err := WriteRequestSuccessMessage(conn, addr.IP, uint16(addr.Port)); err != nil {
		// The client never received its success reply; don't leak the target.
		targetConn.Close()
		return nil, err
	}
	return targetConn, nil
}

// auth performs the SOCKS5 method-negotiation step on conn.
//
// It reads the client's greeting via NewClientAuthMessage. Only the
// "no authentication" method (MethodNoAuth) is supported. If the client
// offers it, auth replies via NewServerAuthMessage with MethodNoAuth and
// returns that call's error. Otherwise it replies with MethodNoAcceptable and
// returns an error; a failure to send that rejection is included in the error.
func auth(conn net.Conn) error {
	log.Printf("new auth %v", conn.RemoteAddr())
	clientMessage, err := NewClientAuthMessage(conn)
	if err != nil {
		return err
	}

	acceptable := false
	for _, method := range clientMessage.Methods {
		if method == MethodNoAuth {
			acceptable = true
			break
		}
	}

	if !acceptable {
		if err := NewServerAuthMessage(conn, MethodNoAcceptable); err != nil {
			return fmt.Errorf("method not supported: %w", err)
		}
		return errors.New("method not supported")
	}
	return NewServerAuthMessage(conn, MethodNoAuth)
}
